import argparse
import ipdb

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.transforms import SIGN

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

from logger import Logger

import numpy as np
from sklearn.preprocessing import normalize

class MLP(torch.nn.Module):
    """Multi-layer perceptron: (Linear -> ReLU -> Dropout) * (L-1) -> Linear.

    Args:
        in_channels: size of each input feature vector.
        hidden_channels: width of every hidden layer.
        out_channels: size of each output vector.
        num_layers: total number of Linear layers (must be >= 1). With
            ``num_layers == 1`` the MLP is a single
            ``Linear(in_channels, out_channels)``.
        dropout: dropout probability applied after each hidden ReLU.

    Raises:
        ValueError: if ``num_layers < 1``.

    The last layer returns raw scores; callers apply any (log-)softmax.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(MLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Layer widths: input, (num_layers - 1) hidden widths, output.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize every Linear layer in place."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x):
        """Apply the MLP to ``x`` and return un-normalized output scores."""
        for lin in self.lins[:-1]:
            x = F.relu(lin(x))
            # Use the configured probability (previously hard-coded to 0.5).
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lins[-1](x)

class SIGN_model(torch.nn.Module):
    """SIGN classifier: one MLP branch per hop, concatenated, then projected.

    The code is based on the implementation from DGL team.
    https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/sign/sign.py

    Args:
        in_channels: feature size of every hop input.
        hidden_channels: output width of each branch MLP and hidden width of
            the projection MLP.
        out_channels: number of output classes.
        num_hops: number of hop inputs ``forward`` expects (one branch each).
        num_layers: number of Linear layers in each MLP.
        dropout: dropout probability after the concatenated branch outputs
            (also forwarded to every MLP).
        input_drop: dropout probability applied to each raw hop input.
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_hops, num_layers, dropout, input_drop):
        super(SIGN_model, self).__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.relu = torch.nn.ReLU()
        self.inception_ffs = torch.nn.ModuleList()
        self.input_drop = torch.nn.Dropout(input_drop)

        for hop in range(num_hops):
            self.inception_ffs.append(
                MLP(in_channels, hidden_channels, hidden_channels,
                    num_layers, dropout))

        self.project = MLP(num_hops * hidden_channels,
                           hidden_channels, out_channels,
                           num_layers, dropout)

    def forward(self, xs):
        """Return log-probabilities (log_softmax over the last dim).

        Args:
            xs: sequence of hop feature tensors, one per branch, all with the
                same rows.

        Raises:
            ValueError: if ``len(xs)`` differs from the number of branches
                (``zip`` would otherwise silently drop inputs/branches).
        """
        if len(xs) != len(self.inception_ffs):
            raise ValueError(
                f"expected {len(self.inception_ffs)} hop inputs, got {len(xs)}")
        # Drop all inputs first, then run the branches (same order as before).
        dropped = [self.input_drop(x) for x in xs]
        hidden = [ff(x) for x, ff in zip(dropped, self.inception_ffs)]
        h = torch.cat(hidden, dim=-1)
        h = self.dropout(self.relu(h))
        return torch.log_softmax(self.project(h), dim=-1)

    def reset_parameters(self):
        """Re-initialize every branch MLP and the projection MLP."""
        for ff in self.inception_ffs:
            ff.reset_parameters()
        self.project.reset_parameters()
    
    
def train(model, xs, y_true, optimizer, train_loader):
    """Run one training epoch and return the mean training loss.

    Args:
        model: module mapping a list of per-hop feature tensors to
            log-probabilities (trained with ``F.nll_loss``).
        xs: list of per-hop feature tensors; their rows are selected by the
            index batches yielded by ``train_loader``.
        y_true: label tensor with a trailing singleton dim (``squeeze(1)`` is
            applied), in the same row order as ``xs``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of index batches (positions into ``xs`` and
            ``y_true``).

    Returns:
        The sample-weighted mean of the batch losses over the epoch, or
        ``float('nan')`` if ``train_loader`` yielded no batches.
    """
    model.train()
    device = y_true.device
    total_loss = 0.0
    total_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch_xs = [x[batch].to(device) for x in xs]
        out = model(batch_xs)
        loss = F.nll_loss(out, y_true[batch].squeeze(1))
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch doesn't skew the mean.
        num_examples = out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    if total_examples == 0:
        return float('nan')
    return total_loss / total_examples


@torch.no_grad()
def test(model, xs, y_true, evaluator):
    model.eval()
    
    y_preds = []
    loader = DataLoader(range(y_true.size(0)), batch_size=100000)
    for perm in loader:
        y_pred = model([x[perm] for x in xs]).argmax(dim=-1, keepdim=True)
        y_preds.append(y_pred.cpu())
    y_pred = torch.cat(y_preds, dim=0)

    return evaluator.eval({
        'y_true': y_true,
        'y_pred': y_pred,
    })['acc']


def _parse_args():
    """Parse the command-line options of the SIGN ogbn-products experiment."""
    parser = argparse.ArgumentParser(description='OGBN-Products (SIGN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=10)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--K', type=int, default=5)
    parser.add_argument('--hidden_channels', type=int, default=512)
    parser.add_argument('--dropout', type=float, default=0.4)
    parser.add_argument('--in_dropout', type=float, default=0.3)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument("--eval-every", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=25000)
    # NOTE(review): not currently forwarded to test(), which uses its own default.
    parser.add_argument("--eval-batch-size", type=int, default=100000,
                        help="evaluation batch size")
    parser.add_argument('--data_root_dir', type=str, default='../../dataset')
    parser.add_argument('--pretrain_path', type=str, default='None')
    # Reject typos instead of silently skipping preprocessing.
    parser.add_argument('--preprocess', type=str, default='None',
                        choices=['None', 'Std', 'Norm'])
    return parser.parse_args()


def _preprocess_features(x, mode):
    """Return node features ``x`` transformed per ``mode`` ('None'/'Std'/'Norm')."""
    if mode == 'Std':
        # Standardize each feature column (statistics taken over dim 0).
        x = x - x.mean(dim=0, keepdim=True)
        # Constant columns have std 0; clamp to avoid 0/0 = NaN.
        x = x / torch.std(x, dim=0, keepdim=True).clamp_min(1e-12)
        print("Node features standardized!")
    elif mode == 'Norm':
        # sklearn's normalize defaults to per-row (axis=1) L2 normalization.
        x = torch.tensor(normalize(x.numpy()))
        print("Node features normalized!")
    return x


def main():
    """Train and evaluate SIGN on ogbn-products for ``args.runs`` runs."""
    args = _parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-products',
                                     root=args.data_root_dir)
    split_idx = dataset.get_idx_split()
    # PyG returns a copy on every dataset[0] access, so writing to
    # dataset[0].x is lost; keep working on this single `data` object.
    data = dataset[0]

    # Load pretrained node features from PECOS.
    # Compare by value: `is not` on strings is identity, not equality.
    if args.pretrain_path != 'None':
        # Cast to float32 so features match the model parameters' dtype.
        data.x = torch.tensor(np.load(args.pretrain_path)).float()
        print("Pretrained node feature loaded! Path: {}".format(args.pretrain_path))

    data.x = _preprocess_features(data.x, args.preprocess)

    data = SIGN(args.K)(data)  # This might take a while.

    xs = [data.x] + [data[f'x{i}'] for i in range(1, args.K + 1)]
    xs_train = [x[split_idx['train']].to(device) for x in xs]
    xs_valid = [x[split_idx['valid']].to(device) for x in xs]
    xs_test = [x[split_idx['test']].to(device) for x in xs]

    y_train_true = data.y[split_idx['train']].to(device)
    y_valid_true = data.y[split_idx['valid']].to(device)
    y_test_true = data.y[split_idx['test']].to(device)

    model = SIGN_model(data.x.size(-1), args.hidden_channels, dataset.num_classes,
                       args.K + 1, args.num_layers, args.dropout,
                       args.in_dropout).to(device)

    evaluator = Evaluator(name='ogbn-products')
    logger = Logger(args.runs, args)

    # xs_train / y_train_true are already restricted to the training split,
    # so minibatches must be local row positions, not global node ids.
    num_train = split_idx['train'].size(0)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        train_loader = DataLoader(torch.arange(num_train),
                                  batch_size=args.batch_size,
                                  shuffle=True, drop_last=False)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, xs_train, y_train_true, optimizer, train_loader)

            if epoch % args.eval_every == 0:
                train_acc = test(model, xs_train, y_train_true, evaluator)
                valid_acc = test(model, xs_valid, y_valid_true, evaluator)
                test_acc = test(model, xs_test, y_test_true, evaluator)
                result = (train_acc, valid_acc, test_acc)
                logger.add_result(run, result)

                # Logged only on evaluation epochs, where `result` exists.
                if epoch % args.log_steps == 0:
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {100 * train_acc:.2f}%, '
                          f'Valid: {100 * valid_acc:.2f}%, '
                          f'Test: {100 * test_acc:.2f}%')

        logger.print_statistics(run)
    logger.print_statistics()


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()
